#!/usr/bin/env python
# encoding: utf-8
"""
Unit tests for the XSLT stylesheet that upgrades 2008-09 OME schema
documents to the 2009-09 schema.
"""

#
#  Copyright (c) 2009 University of Dundee. All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions
#  are met:
#  1. Redistributions of source code must retain the above copyright
#     notice, this list of conditions and the following disclaimer.
#  2. Redistributions in binary form must reproduce the above copyright
#     notice, this list of conditions and the following disclaimer in the
#     documentation and/or other materials provided with the distribution.
#
#  THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
#  ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
#  ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
#  FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
#  DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
#  OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
#  HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
#  LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
#  OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
#  SUCH DAMAGE.

import unittest
import sys

from StringIO import StringIO
from getopt import getopt, GetoptError
from copy import deepcopy
from xsltbasic import XsltBasic
# We're using lxml's ElementTree implementation for XML manipulation due to
# its XSLT integration.
from lxml.etree import XML, XSLT, XMLSchema, XMLSchemaValidateError, Element, SubElement, ElementTree, dump, parse

# Handle Python 2.5 built-in ElementTree
#try:
#        from xml.etree.ElementTree import XML, Element, SubElement, ElementTree, dump
#except ImportError:
#        from elementtree.ElementTree import XML, Element, SubElement, ElementTree, dump


def usage(error):
    """
    Print an error message followed by the command-line usage, then exit.

    error -- text shown on the first line of the output, before the usage.

    This function never returns: it ends the process via sys.exit(2).
    """
    cmd = sys.argv[0]
    # A parenthesised single-argument print gives the same output whether it
    # is parsed as a Python 2 print statement or as the Python 3 function.
    print("""%s
Usage: %s <stylesheet.xsl> <input.xml> [output.xml]
Runs the stylesheet transform upon an XML instance document.

Options:

Examples:
  %s foo.xslt bar.xml bar_new.xml

Report bugs to ome-devel@lists.openmicroscopy.org.uk""" % (error, cmd, cmd))
    sys.exit(2)

def run_stylesheet(xslt, document):
    """
    Apply an XSLT stylesheet to a document and return the result.

    xslt     -- stylesheet source handed to lxml's parse(); the test
                fixture in this file passes an open file object.
    document -- the tree handed to the compiled transform.

    Returns whatever the compiled lxml XSLT object returns when called
    with document.
    """
    stylesheet = XSLT(parse(xslt))
    transformed = stylesheet(document)
    return transformed


class Test200809(XsltBasic):
    """
    Tests for the 2008-09 to 2009-09 upgrade stylesheet.

    setUp transforms DOCUMENT with STYLESHEET before every test. Each test
    then compares the untouched input tree (self.old_root) with the
    transformed tree (self.new_root).

    The comparison helpers used below (checkElementsMapped,
    checkAttributesExcluded, getAllElements, compareGraphs, ...) are not
    defined in this class; they are expected to be inherited from XsltBasic.
    """

    STYLESHEET = "../specification/Xslt/2008-09-to-2009-09.xsl"

    OLD_VALIDATION = "../specification/Released-Schema/2008-09/ome.xsd"
    NEW_VALIDATION = "../specification/Released-Schema/2009-09/ome.xsd"
#    NEW_VALIDATION = "../specification/InProgress/ome.xsd"

    DOCUMENT = "../specification/InProgress/test-ome-samples/sample-2008-09.ome"

    OLD_OME_NS = "http://www.openmicroscopy.org/Schemas/OME/2008-09"

    OLD_SPW_NS = "http://www.openmicroscopy.org/Schemas/SPW/2008-09"

    OLD_AM_NS = "http://www.openmicroscopy.org/Schemas/AnalysisModule/2008-09"

    NEW_OME_NS = "http://www.openmicroscopy.org/Schemas/OME/2009-09"

    NEW_SPW_NS = "http://www.openmicroscopy.org/Schemas/SPW/2009-09"

    # Key under which lxml reports an xml:lang attribute.
    XML_LANG = '{http://www.w3.org/XML/1998/namespace}lang'

    def setUp(self):
        """
        Open the stylesheet and the sample document, run the transform and
        keep the root elements of the input and of the result.
        """
        self.xslt_file = open(self.STYLESHEET)
        self.instance_file = None
        ready = False
        try:
            self.instance_file = open(self.DOCUMENT)
            self.instance_document = parse(self.instance_file)
            self.result = run_stylesheet(self.xslt_file,
                                         self.instance_document)
            ready = True
        finally:
            # unittest does not call tearDown when setUp raises, so the
            # files opened so far have to be released here.
            if not ready:
                self._close_files()
        self.old_root = self.instance_document.getroot()
        self.new_root = self.result.getroot()

    def tearDown(self):
        """Close the files opened by setUp."""
        self._close_files()

    def _close_files(self):
        """Close whichever of the two fixture files have been opened."""
        for handle in (self.xslt_file, self.instance_file):
            if handle is not None:
                handle.close()

    def _count_with_attribute(self, elements, attribute):
        """Return how many of elements carry the given attribute key."""
        return len([element for element in elements
                    if attribute in element.keys()])

    def _load_schema(self, schema_path, tag):
        """
        Parse the XSD at schema_path and return the lxml XMLSchema.

        tag is only used to label the failure message ("OLD" or "NEW").
        The test fails immediately when the schema cannot be loaded, so the
        caller never continues without a schema object.
        """
        try:
            return XMLSchema(parse(schema_path))
        except Exception:
            # sys.exc_info() is used instead of "except ... as ..." so the
            # same line parses on old 2.x interpreters and on Python 3.
            error = sys.exc_info()[1]
            self.fail(
                "Validator Internal error: XSD schema file could not be "
                "loaded [%s]: %s (%s)" % (tag, schema_path, error))

    def _is_schema_valid(self, schema, root, label, failure_note=None):
        """
        Validate root against schema, printing one line per problem.

        label        -- "Input" or "Output"; printed in every diagnostic.
        failure_note -- optional extra line printed when the validation
                        itself raises XMLSchemaValidateError.

        Returns True when the schema error log is empty after validation,
        False otherwise.
        """
        # validating the document tree against the loaded schema
        # according to the docs this should not throw an exception - but it does!
        try:
            schema.validate(root)
            is_valid = True
            for err in schema.error_log:
                is_valid = False
                print("File: %s Line: (%s, %s) Type: %s (%s) - %s\n" % (label, err.line, None, "XSD", None, err.message))
        except XMLSchemaValidateError:
            is_valid = False
            print("File: %s Line: (%s, %s) Type: %s (%s) - %s\n" % (label, None, None, "XML", None, "Processing the XML data has generated an unspecified error in the XML sub-system. This is usually a result of an incorrect top level block. Please check the OME block is well-formed and that the schemaLocation is specified correctly. This may also be caused by a missing namespace prefix or incorrect xmlns attribute."))
            if failure_note is not None:
                print(failure_note)
        return is_valid

    def test_objective_settings(self):
        """
        ObjectiveRef to ObjectiveSettings transformation
        """
        self.checkElementsMappedNoCount(
            self.old_root, self.OLD_OME_NS, self.new_root, self.NEW_OME_NS,
            {'ObjectiveRef': 'ObjectiveSettings'})

    def test_light_source_settings(self):
        """
        LightSourceRef to LightSourceSettings transformation
        """
        self.checkElementsMappedNoCount(
            self.old_root, self.OLD_OME_NS, self.new_root, self.NEW_OME_NS,
            {'LightSourceRef': 'LightSourceSettings'})

    def test_detector_settings(self):
        """
        DetectorRef to DetectorSettings transformation
        """
        self.checkElementsMappedNoCount(
            self.old_root, self.OLD_OME_NS, self.new_root, self.NEW_OME_NS,
            {'DetectorRef': 'DetectorSettings'})

    def test_wellsample_attributes(self):
        """
        Change the WellSample attributes
        Remove Index, Rename PosX to PositionX & PosY to PositionY
        """
        self.checkElementsMapped(
            self.old_root, self.OLD_SPW_NS, self.new_root, self.NEW_SPW_NS,
            {'WellSample': 'WellSample'})
        for removed in ('Index', 'PosX', 'PosY'):
            self.checkAttributesExcluded(
                self.new_root, self.NEW_SPW_NS, 'WellSample', [removed])

        # old WellSample elements carry PosX / PosY
        old_samples = self.old_root.findall(
            './/{%s}WellSample' % self.OLD_SPW_NS)
        # new WellSample elements carry PositionX / PositionY
        new_samples = self.new_root.findall(
            './/{%s}WellSample' % self.NEW_SPW_NS)

        # every old PosX must have become a PositionX
        self.assertEqual(
            self._count_with_attribute(old_samples, 'PosX'),
            self._count_with_attribute(new_samples, 'PositionX'))
        # every old PosY must have become a PositionY
        self.assertEqual(
            self._count_with_attribute(old_samples, 'PosY'),
            self._count_with_attribute(new_samples, 'PositionY'))

    def test_description_no_lang(self):
        """
        Change OME:Description to be a local simple type in each
        element it is used called Description based on xsd:string
        """
        # all old OME:Description elements; the sample must contain some
        old_descriptions = self.old_root.findall(
            './/{%s}Description' % self.OLD_OME_NS)
        self.assertTrue(len(old_descriptions) > 0)
        # old OME:Description elements directly inside SPW:Screen
        old_screen_descriptions = self.old_root.findall(
            './/{%s}Screen/{%s}Description'
            % (self.OLD_SPW_NS, self.OLD_OME_NS))
        # old OME:Description elements directly inside AnalysisModule
        # namespace parents; these are subtracted from the total below
        old_am_count = 0
        for parent in ('AnalysisModule', 'LookupTable', 'FormalInput',
                       'FormalOutput', 'Category'):
            xpath = './/{%s}%s/{%s}Description' % (
                self.OLD_AM_NS, parent, self.OLD_OME_NS)
            old_am_count += len(self.old_root.findall(xpath))

        # new OME:Description elements
        new_ome_descriptions = self.new_root.findall(
            './/{%s}Description' % self.NEW_OME_NS)
        # new SPW:Description elements directly inside SPW:Screen
        new_screen_descriptions = self.new_root.findall(
            './/{%s}Screen/{%s}Description'
            % (self.NEW_SPW_NS, self.NEW_SPW_NS))
        # all new SPW:Description elements
        new_spw_descriptions = self.new_root.findall(
            './/{%s}Description' % self.NEW_SPW_NS)

        # the language attribute must be gone from every new Description;
        # lxml reports xml:lang under its namespaced key, so a bare 'lang'
        # lookup alone would never see it
        for descriptions in (new_ome_descriptions, new_spw_descriptions):
            for element in descriptions:
                keys = element.keys()
                self.assertFalse('lang' in keys)
                self.assertFalse(self.XML_LANG in keys)

        # all old OME - AM must equal new OME + new SPW
        self.assertEqual(
            len(old_descriptions) - old_am_count,
            len(new_ome_descriptions) + len(new_screen_descriptions))
        # old OME in screen must equal new SPW in screen
        self.assertEqual(len(old_screen_descriptions),
                         len(new_screen_descriptions))

    def test_plate_description(self):
        """
        Change the Description attribute to be a local simple type in the Plate
        element to be a Description element based on xsd:string
        """
        # old SPW:Plate elements; the sample must contain some
        old_plates = self.old_root.findall('.//{%s}Plate' % self.OLD_SPW_NS)
        self.assertTrue(len(old_plates) > 0)
        old_attribute_count = self._count_with_attribute(old_plates,
                                                         'Description')

        # the number of plates is unchanged
        new_plates = self.new_root.findall('.//{%s}Plate' % self.NEW_SPW_NS)
        self.assertEqual(len(old_plates), len(new_plates))

        # each old Description attribute became a Description child element
        new_plate_descriptions = self.new_root.findall(
            './/{%s}Plate/{%s}Description'
            % (self.NEW_SPW_NS, self.NEW_SPW_NS))
        self.assertEqual(old_attribute_count, len(new_plate_descriptions))

        # old attributes must not be present in new SPW:Plate nodes
        for plate in new_plates:
            self.assertFalse('Description' in plate.keys())

    def test_dataset_no_locked(self):
        """
        Remove Locked from Dataset
        """
        self.checkElementsMapped(
            self.old_root, self.OLD_OME_NS, self.new_root, self.NEW_OME_NS,
            {'Dataset': 'Dataset'})
        self.checkAttributesExcluded(
            self.new_root, self.NEW_OME_NS, 'Dataset', ['Locked'])

    def test_objective_elements_to_attributes(self):
        """
        Objective transformation. See stylesheet for details
        """
        # NOTE: this check is currently disabled, so the test always passes.
        # The disabled comparison was:
        #
        # fromElements = self.getAllElements(self.old_root, self.OLD_OME_NS, ['Objective'])
        # toAttributes = self.getAllElements(self.new_root, self.NEW_OME_NS, ['Objective'])
        # for i, toAttributeElement in enumerate(toAttributes):
        #     self.compareElementsWithAttributes(fromElements[i], self.OLD_OME_NS, toAttributeElement)

    def test_experimenter_elements_to_attributes(self):
        """
        Experimenter transformation. See stylesheet for details
        """
        old_experimenters = self.getAllElements(
            self.old_root, self.OLD_OME_NS, ['Experimenter'])
        new_experimenters = self.getAllElements(
            self.new_root, self.NEW_OME_NS, ['Experimenter'])
        self.assertEqual(len(old_experimenters), len(new_experimenters))
        if len(old_experimenters) == 0:
            return
        # NOTE(review): the GroupRef children of the FIRST old Experimenter
        # are passed to compareElements for every pair - confirm intended.
        group_refs = self.getChildElements(
            old_experimenters[0], self.OLD_OME_NS, ['GroupRef'])
        # The lengths were asserted equal above, so zip pairs every element.
        for old_element, new_element in zip(old_experimenters,
                                            new_experimenters):
            self.compareElementsWithAttributes(
                old_element, self.OLD_OME_NS, new_element,
                ['GroupRef'], {'OMEName': 'UserName'})
            self.compareElements(
                old_element, self.OLD_OME_NS, new_element, self.NEW_OME_NS,
                group_refs)

    def test_filter_set_elements_to_attributes(self):
        """
        FilterSet transformation. See stylesheet for details
        """
        moved_attributes = ["EmFilterRef", "ExFilterRef", "DichroicRef"]
        old_filter_sets = self.getAllElements(
            self.old_root, self.OLD_OME_NS, ['FilterSet'])
        new_filter_sets = self.getAllElements(
            self.new_root, self.NEW_OME_NS, ['FilterSet'])
        self.assertEqual(len(old_filter_sets), len(new_filter_sets))
        if len(old_filter_sets) == 0:
            return
        # The lengths were asserted equal above, so zip pairs every element.
        for old_element, new_element in zip(old_filter_sets, new_filter_sets):
            self.compareElementsWithAttributesFromMap(
                new_element, old_element,
                {'EmFilterRef': 'EmissionFilterRef',
                 'ExFilterRef': 'ExcitationFilterRef'})
            self.compareAttributes(old_element, new_element, moved_attributes)
        self.checkAttributesExcluded(
            self.new_root, self.NEW_OME_NS, 'FilterSet', moved_attributes)

    def test_tiffdata_rename_numplanes_to_planecount(self):
        """
        TiffData transformation. See stylesheet for details
        """
        old_elements = self.getAllElements(
            self.old_root, self.OLD_OME_NS, ['TiffData'])
        new_elements = self.getAllElements(
            self.new_root, self.NEW_OME_NS, ['TiffData'])

        self.assertEqual(len(old_elements), len(new_elements))
        for old_element, new_element in zip(old_elements, new_elements):
            self.compareAttributes(
                old_element, new_element, None, {"NumPlanes": "PlaneCount"})

    def test_experiment_copy_microbeammanipulation_from_image(self):
        """
        Experiment transformation. See stylesheet for details
        """
        new_experiments = self.getAllElements(
            self.new_root, self.NEW_OME_NS, ["Experiment"])

        for new_experiment in new_experiments:
            # look the old element up by ID rather than by position
            experiment_id = new_experiment.get('ID')
            old_with_refs = self.findElementByID(
                self.old_root, self.OLD_OME_NS, 'Experiment', experiment_id)
            old_without_refs = self.replaceRefsWithElements(
                self.old_root, self.OLD_OME_NS, old_with_refs,
                ["MicrobeamManipulation"])
            self.compareGraphs(
                old_without_refs,
                new_experiment,
                [self.XML_LANG, 'ID'],  # ignoreAttributes
                None,  # renameAttributes
                {"LightSourceRef": "LightSourceSettings"})  # renameElements

    def test_detector_type(self):
        """
        Checks the Detector elements after the transformation.

        Only the number of Detector elements is compared at present; the
        Type values themselves are not inspected.
        """
        old_detectors = self.getAllElements(
            self.old_root, self.OLD_OME_NS, ['Detector'])
        new_detectors = self.getAllElements(
            self.new_root, self.NEW_OME_NS, ['Detector'])
        self.assertEqual(len(old_detectors), len(new_detectors))

    # The tests that follow cover transformations related to ROI elements.

    def test_shape_transformation(self):
        """
        Checks that the shape types have been correctly transformed.

        Placeholder: this test contains no assertions yet.
        """

    def test_validate_input(self):
        """
            Checks that the input validates
        """
        schema = self._load_schema(self.OLD_VALIDATION, "OLD")
        self.assertTrue(
            self._is_schema_valid(schema, self.old_root, "Input"))

    def test_validate_output(self):
        """
            Checks that the output validates
        """
        schema = self._load_schema(self.NEW_VALIDATION, "NEW")
        self.assertTrue(
            self._is_schema_valid(schema, self.new_root, "Output",
                                  "File is invalid"))


# Run every test in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
# NOTE: the string literal below is never executed; it only holds a disabled
# command-line driver. As written it calls run_stylesheet() with three
# arguments, while the run_stylesheet() defined in this file accepts two, and
# it is the only place in this file that refers to usage() and getopt.
"""
    try:
        options, args = getopt(sys.argv[1:], "")
    except GetoptError, (msg, opt):
        usage(msg)

    for option, argument in options:
        pass

    xslt_filename = "2008-09.xsl"
    xslt = open(xslt_filename)
    in_filename = "tmp/sample-lsm-nobindata.ome"
    in_file = open(in_filename)
    try:
        print "Running XSLT %s on %s..." % (xslt_filename, in_filename)
        run_stylesheet(xslt, in_file, None)
    finally:
        xslt.close()
        in_file.close()
"""
